{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mI6aVBQRy6LC"
      },
      "source": [
        "# 3️⃣ Combining Pre-Trained Adapters using AdapterFusion\n",
        "\n",
        "In [the previous notebook](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/master/notebooks/02_Adapter_Inference.ipynb), we loaded a single pre-trained adapter from _AdapterHub_. Now we will explore how to take advantage of multiple pre-trained adapters to combine their knowledge on a new task. Combining multiple adapters together into one 'block' is called an 'adapter composition'. In this notebook, we will explain one such block known as **AdapterFusion** ([Pfeiffer et al., 2020](https://arxiv.org/pdf/2005.00247.pdf)).\n",
        "\n",
        "For this guide, we select **CommitmentBank** ([De Marneffe et al., 2019](https://github.com/mcdm/CommitmentBank)), a three-class textual entailment dataset, as our target task. We will fuse [adapters from AdapterHub](https://adapterhub.ml/explore/) that were pre-trained on different tasks. During training, the weights of these adapters are kept fixed while a newly introduced fusion layer is trained. As our base model, we will use BERT (`bert-base-uncased`)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3L9gYpCV28OA"
      },
      "source": [
        "## Installation\n",
        "\n",
        "Again, we install `adapters` and HuggingFace's `datasets` library first:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T12:14:47.927413900Z",
          "start_time": "2023-08-17T12:14:42.946988300Z"
        },
        "id": "qL3Sq1HQynCq"
      },
      "outputs": [],
      "source": [
        "!pip install -Uq adapters\n",
        "!pip install -q datasets\n",
        "!pip install -q accelerate"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fP9OMSUT-FtL"
      },
      "source": [
        "## Dataset Preprocessing\n",
        "\n",
        "Before setting up training, we first prepare the training data. CommitmentBank is part of the SuperGLUE benchmark and can be loaded via HuggingFace `datasets` using one line of code:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T09:10:12.742774300Z",
          "start_time": "2023-08-17T09:10:06.346891800Z"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "INW7UEhC-I6b",
        "outputId": "12f3b8a8-5960-4eb2-d369-669a785d4b08"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'train': 250, 'validation': 56, 'test': 250}"
            ]
          },
          "execution_count": 1,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "from datasets import load_dataset\n",
        "\n",
        "dataset = load_dataset(\"super_glue\", \"cb\")\n",
        "dataset.num_rows"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "epiKaEz5dDVe"
      },
      "source": [
        "Every dataset sample has a premise, a hypothesis and a three-class label:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T09:10:20.713443200Z",
          "start_time": "2023-08-17T09:10:20.698484900Z"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ifu4q5IJ-hYI",
        "outputId": "fcbcf0e4-1134-421b-ec54-6d0182bc758d"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'premise': Value(dtype='string', id=None),\n",
              " 'hypothesis': Value(dtype='string', id=None),\n",
              " 'idx': Value(dtype='int32', id=None),\n",
              " 'label': ClassLabel(names=['entailment', 'contradiction', 'neutral'], id=None)}"
            ]
          },
          "execution_count": 2,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "dataset['train'].features"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uVa3Vk0QdNYI"
      },
      "source": [
        "Now, we need to encode all dataset samples to valid inputs for our `bert-base-uncased` model. Using `dataset.map()`, we can pass the full dataset through the tokenizer in batches:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T09:10:30.595995300Z",
          "start_time": "2023-08-17T09:10:26.589718800Z"
        },
        "id": "hEnRCQfE_Oi3"
      },
      "outputs": [
      ],
      "source": [
        "from transformers import BertTokenizer\n",
        "\n",
        "tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n",
        "\n",
        "def encode_batch(batch):\n",
        "  \"\"\"Encodes a batch of input data using the model tokenizer.\"\"\"\n",
        "  return tokenizer(\n",
        "      batch[\"premise\"],\n",
        "      batch[\"hypothesis\"],\n",
        "      max_length=180,\n",
        "      truncation=True,\n",
        "      padding=\"max_length\"\n",
        "  )\n",
        "\n",
        "# Encode the input data\n",
        "dataset = dataset.map(encode_batch, batched=True)\n",
        "# The transformers model expects the target class column to be named \"labels\"\n",
        "dataset = dataset.rename_column(\"label\", \"labels\")\n",
        "# Transform to pytorch tensors and only output the required columns\n",
        "dataset.set_format(type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"labels\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TsXEy1TNeAnI"
      },
      "source": [
        "Now we're ready to set up AdapterFusion."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Xs21MzEQ_0v4"
      },
      "source": [
        "## Fusion Training\n",
        "\n",
        "We use a pre-trained BERT model from HuggingFace and instantiate our model using `BertAdapterModel`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T09:11:01.706736300Z",
          "start_time": "2023-08-17T09:10:59.697115300Z"
        },
        "id": "fnq8n_KP_3aX"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Some weights of BertAdapterModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['heads.default.3.bias']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ]
        }
      ],
      "source": [
        "from transformers import BertConfig\n",
        "from adapters import BertAdapterModel\n",
        "\n",
        "id2label = {id: label for (id, label) in enumerate(dataset[\"train\"].features[\"labels\"].names)}\n",
        "\n",
        "config = BertConfig.from_pretrained(\n",
        "    \"bert-base-uncased\",\n",
        "    id2label=id2label,\n",
        ")\n",
        "model = BertAdapterModel.from_pretrained(\n",
        "    \"bert-base-uncased\",\n",
        "    config=config,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BDCNJzTXezcn"
      },
      "source": [
        "Now we have everything in place to build our _AdapterFusion_ setup. First, we load three adapters pre-trained on different tasks from the Hub: MultiNLI, QQP and QNLI. As we don't need their prediction heads, we pass `with_head=False` to the loading method. Next, we add a new fusion layer that combines all the adapters we've just loaded. Finally, we add a new classification head for our target task on top.\n",
        "\n",
        "We define the fusion layer with a `Fuse` block from the `composition` module. A `Fuse` block combines multiple pre-trained adapters for a new downstream task. Just like `add_adapter` from the previous notebooks, `add_adapter_fusion` introduces an untrained fusion layer with randomly initialized weights. These weights are then updated when the model is trained on the target dataset.\n",
        "\n",
        "To learn more about `AdapterFusion` you can check out: https://docs.adapterhub.ml/adapter_composition.html#fuse"
      ]
    },
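    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To build some intuition for what a fusion layer learns, the next cell is a minimal, simplified sketch of the attention mechanism described in the AdapterFusion paper: the hidden state of the layer acts as the query, while the outputs of the individual adapters act as keys and values. This is not the implementation used by `adapters`. The class name `ToyFusion` and all tensor sizes are assumptions chosen purely for illustration, and the adapter outputs are random stand-ins."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "\n",
        "class ToyFusion(nn.Module):\n",
        "    # Simplified AdapterFusion-style attention over adapter outputs (illustrative only)\n",
        "    def __init__(self, hidden_size):\n",
        "        super().__init__()\n",
        "        self.query = nn.Linear(hidden_size, hidden_size)\n",
        "        self.key = nn.Linear(hidden_size, hidden_size)\n",
        "        self.value = nn.Linear(hidden_size, hidden_size, bias=False)\n",
        "\n",
        "    def forward(self, hidden, adapter_outputs):\n",
        "        # hidden: (batch, seq, hidden)\n",
        "        # adapter_outputs: (batch, seq, n_adapters, hidden)\n",
        "        q = self.query(hidden).unsqueeze(2)  # (batch, seq, 1, hidden)\n",
        "        k = self.key(adapter_outputs)        # (batch, seq, n_adapters, hidden)\n",
        "        v = self.value(adapter_outputs)      # (batch, seq, n_adapters, hidden)\n",
        "        # One score per adapter and token: dot product of query and key\n",
        "        scores = (q * k).sum(dim=-1)         # (batch, seq, n_adapters)\n",
        "        weights = torch.softmax(scores, dim=-1)\n",
        "        # Weighted sum of the adapter values\n",
        "        fused = (weights.unsqueeze(-1) * v).sum(dim=2)  # (batch, seq, hidden)\n",
        "        return fused, weights\n",
        "\n",
        "torch.manual_seed(0)\n",
        "fusion = ToyFusion(hidden_size=16)\n",
        "hidden = torch.randn(2, 5, 16)              # stand-in for a transformer layer output\n",
        "adapter_outputs = torch.randn(2, 5, 3, 16)  # stand-in for the outputs of three adapters\n",
        "fused, weights = fusion(hidden, adapter_outputs)\n",
        "print(fused.shape, weights.shape)"
      ]
    },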
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T09:11:36.098695Z",
          "start_time": "2023-08-17T09:11:27.950502200Z"
        },
        "id": "jRqbBgS0BoHJ"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Automatic redirect to HF Model Hub repo 'AdapterHub/bert-base-uncased_nli_multinli_pfeiffer'. Please switch to the new ID to remove this warning.\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "ec311abc431b429fa2fc9eb3bff51506",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Fetching 7 files:   0%|          | 0/7 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Automatic redirect to HF Model Hub repo 'AdapterHub/bert-base-uncased_sts_qqp_pfeiffer'. Please switch to the new ID to remove this warning.\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "a8882b442f0d44f49cb427f867847eaa",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Fetching 7 files:   0%|          | 0/7 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Automatic redirect to HF Model Hub repo 'AdapterHub/bert-base-uncased_nli_qnli_pfeiffer'. Please switch to the new ID to remove this warning.\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "6368ba958cb342b690a20624effcb08c",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Fetching 7 files:   0%|          | 0/7 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "from adapters.composition import Fuse\n",
        "\n",
        "# Load the pre-trained adapters we want to fuse\n",
        "model.load_adapter(\"nli/multinli@ukp\", load_as=\"multinli\", with_head=False)\n",
        "model.load_adapter(\"sts/qqp@ukp\", with_head=False)\n",
        "model.load_adapter(\"nli/qnli@ukp\", with_head=False)\n",
        "# Add a fusion layer for all loaded adapters\n",
        "adapter_setup = Fuse(\"multinli\", \"qqp\", \"qnli\")\n",
        "model.add_adapter_fusion(adapter_setup)\n",
        "\n",
        "# Add a classification head for our target task\n",
        "model.add_classification_head(\"cb\", num_labels=len(id2label))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "60qIas8-il92"
      },
      "source": [
        "The last preparation step is to activate our adapter setup for training. Similar to `train_adapter()`, `train_adapter_fusion()` does two things: it freezes all weights of the model (including the adapters!) except for the fusion layer and the classification head, and it activates the given adapter setup so that it is used in every forward pass."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T09:12:05.113045600Z",
          "start_time": "2023-08-17T09:12:05.083127700Z"
        },
        "id": "zgGqHJQbijgg"
      },
      "outputs": [],
      "source": [
        "# Unfreeze and activate fusion setup\n",
        "model.train_adapter_fusion(adapter_setup)"
      ]
    },
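    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "One simple way to check the effect of freezing is to count trainable parameters. The helper in the next cell is a small sketch: `count_parameters` is a hypothetical name and not part of `adapters`. It is demonstrated on a toy module so that the cell runs on its own. Calling `count_parameters(model)` on the model from above should report only a small fraction of the parameters as trainable."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch.nn as nn\n",
        "\n",
        "def count_parameters(module):\n",
        "    # Returns (number of trainable parameters, total number of parameters)\n",
        "    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)\n",
        "    total = sum(p.numel() for p in module.parameters())\n",
        "    return trainable, total\n",
        "\n",
        "# Toy stand-in for a model: freeze the first layer, keep the second trainable\n",
        "toy = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 3))\n",
        "for p in toy[0].parameters():\n",
        "    p.requires_grad = False\n",
        "\n",
        "trainable, total = count_parameters(toy)\n",
        "print(f'{trainable} of {total} parameters are trainable')"
      ]
    },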
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Fb8BY5RAmzkd"
      },
      "source": [
        "For training, we make use of the `AdapterTrainer` class built into `adapters`. We configure the training process using a `TrainingArguments` object and define a function that calculates the accuracy during evaluation. We pass both, together with the training and validation splits of our dataset, to the trainer instance."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T09:12:34.639028100Z",
          "start_time": "2023-08-17T09:12:34.319882200Z"
        },
        "id": "j0gFxQRdDkQ6"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "from transformers import TrainingArguments, EvalPrediction\n",
        "from adapters import AdapterTrainer\n",
        "\n",
        "training_args = TrainingArguments(\n",
        "    learning_rate=5e-5,\n",
        "    num_train_epochs=5,\n",
        "    per_device_train_batch_size=32,\n",
        "    per_device_eval_batch_size=32,\n",
        "    logging_steps=200,\n",
        "    output_dir=\"./training_output\",\n",
        "    overwrite_output_dir=True,\n",
        "    # The next line is important to ensure the dataset labels are properly passed to the model\n",
        "    remove_unused_columns=False,\n",
        ")\n",
        "\n",
        "def compute_accuracy(p: EvalPrediction):\n",
        "  preds = np.argmax(p.predictions, axis=1)\n",
        "  return {\"acc\": (preds == p.label_ids).mean()}\n",
        "\n",
        "trainer = AdapterTrainer(\n",
        "    model=model,\n",
        "    args=training_args,\n",
        "    train_dataset=dataset[\"train\"],\n",
        "    eval_dataset=dataset[\"validation\"],\n",
        "    compute_metrics=compute_accuracy,\n",
        ")"
      ]
    },
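    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the metric concrete, the next cell applies the same argmax-and-compare logic as `compute_accuracy` to a handful of made-up logits and labels. The numbers are invented for illustration and are not model outputs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Made-up logits for four samples and three classes (illustrative values only)\n",
        "logits = np.array([\n",
        "    [2.0, 0.1, 0.3],\n",
        "    [0.2, 1.5, 0.1],\n",
        "    [0.3, 0.2, 0.9],\n",
        "    [1.2, 0.4, 0.1],\n",
        "])\n",
        "labels = np.array([0, 1, 2, 1])\n",
        "\n",
        "# The predicted class is the index of the largest logit in each row\n",
        "preds = np.argmax(logits, axis=1)\n",
        "# Accuracy is the fraction of predictions that match the labels\n",
        "accuracy = (preds == labels).mean()\n",
        "print(preds, accuracy)"
      ]
    },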
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iKlYjA9rm2Kp"
      },
      "source": [
        "Start the training 🚀 (this will take a while)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T09:13:43.869798100Z",
          "start_time": "2023-08-17T09:12:46.140248700Z"
        },
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 165
        },
        "id": "vSUs2FjXDmsx",
        "outputId": "ab2ca72e-d617-4e64-9ea4-49fd41a2ac9f"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "d0b7a34d62ab44689f50f25f9912915a",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "  0%|          | 0/40 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "{'train_runtime': 332.7348, 'train_samples_per_second': 3.757, 'train_steps_per_second': 0.12, 'train_loss': 0.7115382671356201, 'epoch': 5.0}\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "TrainOutput(global_step=40, training_loss=0.7115382671356201, metrics={'train_runtime': 332.7348, 'train_samples_per_second': 3.757, 'train_steps_per_second': 0.12, 'total_flos': 149577058350000.0, 'train_loss': 0.7115382671356201, 'epoch': 5.0})"
            ]
          },
          "execution_count": 8,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "trainer.train()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "axvDsmnJnGUG"
      },
      "source": [
        "After training has completed, let's check how well our setup performs on the validation set of our target dataset:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T09:13:50.616742100Z",
          "start_time": "2023-08-17T09:13:47.526014400Z"
        },
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 141
        },
        "id": "PqcApZ_-DpwK",
        "outputId": "f83a72f6-bd16-4d87-f6e1-f39aafee2917"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "b0dba7265ff94dcb943d43f30a90d525",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "  0%|          | 0/2 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/plain": [
              "{'eval_loss': 0.6307274699211121,\n",
              " 'eval_acc': 0.7678571428571429,\n",
              " 'eval_runtime': 10.3644,\n",
              " 'eval_samples_per_second': 5.403,\n",
              " 'eval_steps_per_second': 0.193,\n",
              " 'epoch': 5.0}"
            ]
          },
          "execution_count": 9,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "trainer.evaluate()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "op8M7AfYnWhs"
      },
      "source": [
        "We can also use our setup to make some predictions (the example is from the test set of CB):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T09:19:10.426805700Z",
          "start_time": "2023-08-17T09:19:10.185453300Z"
        },
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "id": "e9wQcaNLNPAT",
        "outputId": "28fcbc78-0005-4591-d944-ea73c906f273"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "'contradiction'"
            ]
          },
          "execution_count": 10,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "import torch\n",
        "\n",
        "def predict(premise, hypothesis):\n",
        "  encoded = tokenizer(premise, hypothesis, return_tensors=\"pt\")\n",
        "  if torch.cuda.is_available():\n",
        "    encoded.to(\"cuda\")\n",
        "  logits = model(**encoded)[0]\n",
        "  pred_class = torch.argmax(logits).item()\n",
        "  return id2label[pred_class]\n",
        "\n",
        "predict(\"\"\"\n",
        "``It doesn't happen very often.'' Karen went home\n",
        "happy at the end of the day. She didn't think that\n",
        "the work was difficult.\n",
        "\"\"\",\n",
        "\"the work was difficult\"\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "H-fJK8OtnrAu"
      },
      "source": [
        "Finally, we can extract and save our fusion layer as well as all the adapters we used for training. Both can later be reloaded into the pre-trained model again."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2023-08-17T09:21:24.639621300Z",
          "start_time": "2023-08-17T09:21:24.413228700Z"
        },
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "YpHIRIHxJd5d",
        "outputId": "f6a7ed99-dc46-4331-c7d6-907508182ad0"
      },
      "outputs": [
      ],
      "source": [
        "model.save_adapter_fusion(\"./saved\", \"multinli,qqp,qnli\")\n",
        "model.save_all_adapters(\"./saved\")\n",
        "\n",
        "!ls -l saved"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4PHWpVLApDWt"
      },
      "source": [
        "That's it. Do check out [the paper on AdapterFusion](https://arxiv.org/pdf/2005.00247.pdf) for a more theoretical view on what we've just seen.\n",
        "\n",
        "➡️ `adapters` also enables other composition methods beyond AdapterFusion. For example, check out [the next notebook in this series](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/master/notebooks/04_Cross_Lingual_Transfer.ipynb) on cross-lingual transfer."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "env",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.4"
    },
    "pycharm": {
      "stem_cell": {
        "cell_type": "raw",
        "metadata": {
          "collapsed": false
        },
        "source": []
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
